from matplotlib import pyplot as plt
import pandas as pd
import numpy as np

# Y-axis labels for the action scatter plots; row value j is labelled ACTION_LABELS[j].
# NOTE(review): assumes the recorded CSV columns are in this same order — confirm
# against the recording code.
ACTION_LABELS = ["Button", "None", "Select", "Reset", "Up", "Down", "Left", "Right"]
# Y value used for steps in which no action column is set.
NONE_ACTION_INDEX = 1


def action_points(actions):
    """Convert a per-step action table into (step, action_index) scatter points.

    Every truthy cell ``actions[i][j]`` yields the point ``[i, j]``. A row with no
    truthy cell yields a single ``[i, NONE_ACTION_INDEX]`` point so that idle steps
    remain visible on the plot.

    Args:
        actions: 2-D array-like of action flags, one row per time step.

    Returns:
        Integer ndarray of shape ``(k, 2)``; ``k`` is 0 when ``actions`` has no rows
        (the original inline code raised IndexError in that case).
    """
    points = []
    for step, row in enumerate(np.asarray(actions)):
        active = [col for col, flag in enumerate(row) if flag]
        if active:
            points.extend([step, col] for col in active)
        else:
            points.append([step, NONE_ACTION_INDEX])
    # reshape keeps the 2-column shape even when no points were produced
    return np.array(points, dtype=int).reshape(-1, 2)


def plot_actions(ax, csv_path, title, label, color, xlim=(390, 810), xticks=(400, 800)):
    """Read an action-recording CSV and draw it as an action-vs-step scatter on ``ax``.

    Args:
        ax: matplotlib Axes to draw on.
        csv_path: path of the CSV to read with ``pd.read_csv``.
        title: axes title.
        label: legend label for the scatter series.
        color: marker color.
        xlim: x-axis (time-step) limits; the default is the window used for all plots.
        xticks: x-axis tick positions; each tick is labelled with its own value.
    """
    points = action_points(pd.read_csv(csv_path).to_numpy())
    ax.scatter(points[:, 0], points[:, 1], s=1.5, marker=",", label=label, c=color)
    ax.set_xlim(list(xlim))
    ax.set_xticks(list(xticks))
    ax.set_xticklabels([str(t) for t in xticks], fontsize=12)
    ax.set_yticks(list(range(len(ACTION_LABELS))))
    ax.set_yticklabels(ACTION_LABELS, fontsize=10)
    ax.set_title(title, fontsize=16)
    ax.set_ylabel("Action", fontsize=14)
    ax.grid(True, axis="y", alpha=0.3)


def main():
    """Plot PPO vs. distilled-policy actions for Pong and Seaquest and save them as a PDF."""
    fig, (ax1, ax2, ax3, ax4) = plt.subplots(nrows=4, ncols=1, figsize=(8, 16))
    colors = plt.rcParams["axes.prop_cycle"].by_key()["color"]
    symrel_color = colors[3]
    ppo_color = colors[0]

    plot_actions(
        ax1,
        "recordings/performance-pong-ppo/pong-actions-ppo-list.csv",
        "Pong - PPO Evaluation",
        "PPO",
        ppo_color,
    )
    plot_actions(
        ax2,
        "recordings/video-1-pong-atari2600/pong-actions-symrel-list.csv",
        "Pong - Distilled Policy Evaluation",
        "Distilled Symbolic Rule",
        symrel_color,
    )
    # Unfiltered alternative: recordings/performance-seaquest/seaquest-actions-ppo-list.csv
    plot_actions(
        ax3,
        "recordings/performance-seaquest/seaquest-actions-ppo-filtered-list.csv",
        "Seaquest - PPO Evaluation",
        "PPO",
        ppo_color,
    )
    plot_actions(
        ax4,
        "recordings/performance-seaquest/seaquest-actions-symrel-list.csv",
        "Seaquest - Distilled Policy Evaluation",
        "Distilled Symbolic Rule",
        symrel_color,
    )

    # Save the figure (it is not shown interactively)
    fig.savefig("performance-plot_2.pdf", bbox_inches="tight")
    return fig


if __name__ == "__main__":
    main()
